Behavior of the median filter: spike trains with different distances between spikes

DW (2015.11.30)


In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
from scipy.signal import medfilt, triang
import gitInformation

In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [3]:
# Record provenance: prints date, Python version and git state (SHA, branch,
# remotes) of the repository this notebook was run from.
gitInformation.printInformation()


Information about this notebook
============================================================
Date: 2016-01-21
Python Version: 2.7.10 |Anaconda 2.3.0 (64-bit)| (default, May 28 2015, 16:44:52) [MSC v.1500 64 bit (AMD64)]
Git directory: C:\Users\Dominik\Documents\GitRep\kt-2015-DSPHandsOn\.git
Current git SHA: e6d8cd8a76886c134c66e781ccd5c61afc7f9e75
Remotes: origin, 
Current branch: master
origin remote URL: https://github.com/dowa4213/kt-2015-DSPHandsOn.git

In [4]:
# Triangular spike templates scaled to a peak value of 1.0:
#   tri  - 11 samples, starts and ends at 0 (used for the first spike)
#   tri2 - the same spike without the leading 0 (10 samples, used for
#          every following spike)
tri = np.array([0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 0]) / 5.
tri2 = tri[1:].copy()
plt.plot(tri)


Out[4]:
[<matplotlib.lines.Line2D at 0x16df1470>]

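Generate 6 spike trains with different distances between the spikes (0, 5, 10, 15, 20 and 25 zeros after each spike, i.e. peak-to-peak spacings of 10, 15, 20, 25, 30 and 35 samples)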


In [5]:
# Spike train 1 (no gap): `tri` followed directly by eleven copies of `tri2`.
data1 = np.concatenate([tri] + [tri2] * 11)

In [6]:
# Spike train 2: every spike is followed by a gap `x` of 5 zeros.
x = np.zeros(5)
data2 = np.concatenate([tri, x] + [tri2, x] * 10)

In [7]:
# Spike train 3: every spike is followed by a gap `x` of 10 zeros.
x = np.zeros(10)
data3 = np.concatenate([tri, x] + [tri2, x] * 10)

In [8]:
# Spike train 4: every spike is followed by a gap `x` of 15 zeros.
x = np.zeros(15)
data4 = np.concatenate([tri, x] + [tri2, x] * 10)

In [9]:
# Spike train 5: every spike is followed by a gap `x` of 20 zeros.
x = np.zeros(20)
data5 = np.concatenate([tri, x] + [tri2, x] * 10)

In [10]:
# Spike train 6: every spike is followed by a gap `x` of 25 zeros.
x = np.zeros(25)
data6 = np.concatenate([tri, x] + [tri2, x] * 10)

Plot spike trains


In [13]:
# Show the first 100 samples of each spike train in its own wide figure
# (figure numbers 1..6). The third train (10 zeros between spikes) is also
# written to Spiketrain.png.
trains = [data1, data2, data3, data4, data5, data6]
for fig_num, train in enumerate(trains, start=1):
    plt.figure(fig_num, figsize=(30, 7))
    plt.axis([0, 100, 0, 1])
    plt.plot(train)
    if fig_num == 3:
        plt.savefig("Spiketrain.png", dpi=400)


Out[13]:
[<matplotlib.lines.Line2D at 0x18dc6828>]

Median-filter 140 noisy copies of each spike train for every window length and compute the mean RMS of (clean train - filter output)


In [14]:
# Window lengths swept by the median filter (all odd, as scipy's medfilt
# expects an odd kernel size).
wl = np.asarray((3, 5, 11, 15, 17, 21, 25, 31, 35))

In [15]:
# Monte Carlo RMS for spike train 1. For every window length in wl:
#   - add Gaussian noise (sigma 0.15) to the clean train n_trials times,
#   - median-filter each noisy copy,
#   - take the RMS of (clean train - filter output).
# values1[k] is the mean of those RMS values for window wl[k].
np.random.seed(1)  # fixed seed so Restart & Run All reproduces values1
n_trials = 140
filtered1 = np.zeros((n_trials, len(data1)))
filtered_spikes = np.zeros((n_trials, len(data1)))
rms = np.zeros(n_trials)
values1 = np.zeros(len(wl))
for count, w in enumerate(wl):
    for i in range(n_trials):
        data1_noised = data1 + np.random.normal(0, 0.15, len(data1))
        filtered1[i, :] = medfilt(data1_noised, w)
        filtered_spikes[i, :] = data1 - filtered1[i, :]
        rms[i] = np.sqrt(np.mean(np.square(filtered_spikes[i, :])))
    values1[count] = np.mean(rms)

In [16]:
# Monte Carlo RMS for spike train 2. For every window length in wl:
#   - add Gaussian noise (sigma 0.15) to the clean train n_trials times,
#   - median-filter each noisy copy,
#   - take the RMS of (clean train - filter output).
# values2[k] is the mean of those RMS values for window wl[k].
np.random.seed(2)  # fixed seed so Restart & Run All reproduces values2
n_trials = 140
filtered2 = np.zeros((n_trials, len(data2)))
filtered_spikes = np.zeros((n_trials, len(data2)))
rms = np.zeros(n_trials)
values2 = np.zeros(len(wl))
for count, w in enumerate(wl):
    for i in range(n_trials):
        data2_noised = data2 + np.random.normal(0, 0.15, len(data2))
        filtered2[i, :] = medfilt(data2_noised, w)
        filtered_spikes[i, :] = data2 - filtered2[i, :]
        rms[i] = np.sqrt(np.mean(np.square(filtered_spikes[i, :])))
    values2[count] = np.mean(rms)

In [17]:
# Monte Carlo RMS for spike train 3. For every window length in wl:
#   - add Gaussian noise (sigma 0.15) to the clean train n_trials times,
#   - median-filter each noisy copy,
#   - take the RMS of (clean train - filter output).
# values3[k] is the mean of those RMS values for window wl[k].
np.random.seed(3)  # fixed seed so Restart & Run All reproduces values3
n_trials = 140
filtered3 = np.zeros((n_trials, len(data3)))
filtered_spikes = np.zeros((n_trials, len(data3)))
rms = np.zeros(n_trials)
values3 = np.zeros(len(wl))
for count, w in enumerate(wl):
    for i in range(n_trials):
        data3_noised = data3 + np.random.normal(0, 0.15, len(data3))
        filtered3[i, :] = medfilt(data3_noised, w)
        filtered_spikes[i, :] = data3 - filtered3[i, :]
        rms[i] = np.sqrt(np.mean(np.square(filtered_spikes[i, :])))
    values3[count] = np.mean(rms)

In [18]:
# Monte Carlo RMS for spike train 4. For every window length in wl:
#   - add Gaussian noise (sigma 0.15) to the clean train n_trials times,
#   - median-filter each noisy copy,
#   - take the RMS of (clean train - filter output).
# values4[k] is the mean of those RMS values for window wl[k].
np.random.seed(4)  # fixed seed so Restart & Run All reproduces values4
n_trials = 140
filtered4 = np.zeros((n_trials, len(data4)))
filtered_spikes = np.zeros((n_trials, len(data4)))
rms = np.zeros(n_trials)
values4 = np.zeros(len(wl))
for count, w in enumerate(wl):
    for i in range(n_trials):
        data4_noised = data4 + np.random.normal(0, 0.15, len(data4))
        filtered4[i, :] = medfilt(data4_noised, w)
        filtered_spikes[i, :] = data4 - filtered4[i, :]
        rms[i] = np.sqrt(np.mean(np.square(filtered_spikes[i, :])))
    values4[count] = np.mean(rms)

In [19]:
# Monte Carlo RMS for spike train 5. For every window length in wl:
#   - add Gaussian noise (sigma 0.15) to the clean train n_trials times,
#   - median-filter each noisy copy,
#   - take the RMS of (clean train - filter output).
# values5[k] is the mean of those RMS values for window wl[k].
np.random.seed(5)  # fixed seed so Restart & Run All reproduces values5
n_trials = 140
filtered5 = np.zeros((n_trials, len(data5)))
filtered_spikes = np.zeros((n_trials, len(data5)))
rms = np.zeros(n_trials)
values5 = np.zeros(len(wl))
for count, w in enumerate(wl):
    for i in range(n_trials):
        data5_noised = data5 + np.random.normal(0, 0.15, len(data5))
        filtered5[i, :] = medfilt(data5_noised, w)
        filtered_spikes[i, :] = data5 - filtered5[i, :]
        rms[i] = np.sqrt(np.mean(np.square(filtered_spikes[i, :])))
    values5[count] = np.mean(rms)

In [20]:
# Monte Carlo RMS for spike train 6. For every window length in wl:
#   - add Gaussian noise (sigma 0.15) to the clean train n_trials times,
#   - median-filter each noisy copy,
#   - take the RMS of (clean train - filter output).
# values6[k] is the mean of those RMS values for window wl[k].
np.random.seed(6)  # fixed seed so Restart & Run All reproduces values6
n_trials = 140
filtered6 = np.zeros((n_trials, len(data6)))
filtered_spikes = np.zeros((n_trials, len(data6)))
rms = np.zeros(n_trials)
values6 = np.zeros(len(wl))
for count, w in enumerate(wl):
    for i in range(n_trials):
        data6_noised = data6 + np.random.normal(0, 0.15, len(data6))
        filtered6[i, :] = medfilt(data6_noised, w)
        filtered_spikes[i, :] = data6 - filtered6[i, :]
        rms[i] = np.sqrt(np.mean(np.square(filtered_spikes[i, :])))
    values6[count] = np.mean(rms)

Plot the mean RMS of each spike train against the window length


In [19]:
# Colour table used for the RMS curves. Rows are selected as
# viridis_data[row, :] and passed to matplotlib as line colours.
# NOTE(review): presumably 256 rows of RGB floats - confirm the file format.
viridis_data = np.loadtxt(fname='viridis_data.txt')

In [20]:
# Mean RMS versus window length, one curve per spike train. Colours come from
# viridis rows 40, 80, ..., 240.
all_values = (values1, values2, values3, values4, values5, values6)
for curve, row in zip(all_values, range(40, 241, 40)):
    plt.plot(wl, curve, color=viridis_data[row, :])
plt.xlabel('Window length', fontsize=15)
plt.ylabel('RMS', fontsize=15)


Out[20]:
<matplotlib.text.Text at 0x1748e128>

In [54]:
# One noisy realization of spike train 3 (10 zeros between spikes), median
# filtered with window length 15. Curves: noisy input (blue), filter output
# (green), clean train minus filter output (red).
# The result goes into its own variable: the original code rebound filtered1
# (the 2-D Monte Carlo array), which broke the later per-window cell
# `filtered1[count,:] = ...` on an in-order run.
plt.figure()
filtered3_wl15 = medfilt(data3_noised, 15)
plt.plot(data3_noised, color='cornflowerblue')
plt.axis((0, len(data3), -0.7, 1.2))  # x-range of data3, the train plotted here
plt.plot(filtered3_wl15, color='g', lw=1.5)
plt.plot(data3 - filtered3_wl15, color='r', lw=0.5)
# NOTE(review): the file name says "wl10" and "sigma 0.1", but this figure uses
# window length 15, and data3_noised was generated with sigma 0.15 -
# confirm the intended name.
plt.savefig("Spiketrain1wl10 sigmal= 0.1.png", dpi=400)



In [21]:
# The last noisy realization of spike train 1 drawn in its Monte Carlo cell,
# median filtered with every window length in wl. Curves: noisy input (blue),
# clean train minus filter output (red), filter output (green).
for w in wl:
    # Unnumbered figure: numbered figures 0..8 could draw into figures left
    # open by earlier cells when a non-inline backend is used.
    plt.figure(figsize=(30, 7))
    plt.axis((0, len(data1), -0.7, 1.2))
    plt.plot(data1_noised, color='cornflowerblue')
    # Local result rather than filtered1[count, :]. This leaves the Monte
    # Carlo array untouched and works whatever filtered1 currently holds.
    smoothed = medfilt(data1_noised, w)
    plt.plot(data1 - smoothed, color='r', lw=0.5)
    plt.plot(smoothed, color='g', lw=1.5)
    plt.xlabel('Window length ' + str(w), fontsize=25)



In [22]:
# The last noisy realization of spike train 2 drawn in its Monte Carlo cell,
# median filtered with every window length in wl. Curves: noisy input (blue),
# clean train minus filter output (red), filter output (green).
for w in wl:
    # Unnumbered figure: numbered figures 0..8 could draw into figures left
    # open by earlier cells when a non-inline backend is used.
    plt.figure(figsize=(30, 7))
    plt.axis((0, len(data2), -0.5, 1.3))
    plt.plot(data2_noised, color='cornflowerblue')
    # Local result rather than filtered2[count, :], so the Monte Carlo array
    # is left untouched.
    smoothed = medfilt(data2_noised, w)
    plt.plot(data2 - smoothed, color='r', lw=0.5)
    plt.plot(smoothed, color='g', lw=1.5)
    plt.xlabel('Window length ' + str(w), fontsize=25)



In [23]:
# The last noisy realization of spike train 3 drawn in its Monte Carlo cell,
# median filtered with every window length in wl. Curves: noisy input (blue),
# clean train minus filter output (red), filter output (green).
for w in wl:
    # Unnumbered figure: numbered figures 0..8 could draw into figures left
    # open by earlier cells when a non-inline backend is used.
    plt.figure(figsize=(30, 7))
    plt.axis((0, len(data3), -0.5, 1.3))
    plt.plot(data3_noised, color='cornflowerblue')
    # Local result rather than filtered3[count, :], so the Monte Carlo array
    # is left untouched.
    smoothed = medfilt(data3_noised, w)
    plt.plot(data3 - smoothed, color='r', lw=0.5)
    plt.plot(smoothed, color='g', lw=1.5)
    plt.xlabel('Window length ' + str(w), fontsize=25)



In [24]:
# The last noisy realization of spike train 4 drawn in its Monte Carlo cell,
# median filtered with every window length in wl. Curves: noisy input (blue),
# clean train minus filter output (red), filter output (green).
for w in wl:
    # Unnumbered figure: numbered figures 0..8 could draw into figures left
    # open by earlier cells when a non-inline backend is used.
    plt.figure(figsize=(30, 7))
    plt.axis((0, len(data4), -0.7, 1.2))
    plt.plot(data4_noised, color='cornflowerblue')
    # Local result rather than filtered4[count, :], so the Monte Carlo array
    # is left untouched.
    smoothed = medfilt(data4_noised, w)
    plt.plot(data4 - smoothed, color='r', lw=0.5)
    plt.plot(smoothed, color='g', lw=1.5)
    plt.xlabel('Window length ' + str(w), fontsize=25)



In [25]:
# The last noisy realization of spike train 5 drawn in its Monte Carlo cell,
# median filtered with every window length in wl. Curves: noisy input (blue),
# clean train minus filter output (red), filter output (green).
for w in wl:
    # Unnumbered figure: numbered figures 0..8 could draw into figures left
    # open by earlier cells when a non-inline backend is used.
    plt.figure(figsize=(30, 7))
    plt.axis((0, len(data5), -0.7, 1.2))
    plt.plot(data5_noised, color='cornflowerblue')
    # Local result rather than filtered5[count, :], so the Monte Carlo array
    # is left untouched.
    smoothed = medfilt(data5_noised, w)
    plt.plot(data5 - smoothed, color='r', lw=0.5)
    plt.plot(smoothed, color='g', lw=1.5)
    plt.xlabel('Window length ' + str(w), fontsize=25)



In [26]:
# The last noisy realization of spike train 6 drawn in its Monte Carlo cell,
# median filtered with every window length in wl. Curves: noisy input (blue),
# clean train minus filter output (red), filter output (green).
# (This cell previously repeated the data5 cell verbatim, so spike train 6 was
# never shown and filtered5 was overwritten.)
for w in wl:
    # Unnumbered figure: numbered figures 0..8 could draw into figures left
    # open by earlier cells when a non-inline backend is used.
    plt.figure(figsize=(30, 7))
    plt.axis((0, len(data6), -0.7, 1.2))
    plt.plot(data6_noised, color='cornflowerblue')
    # Local result rather than writing into a Monte Carlo array.
    smoothed = medfilt(data6_noised, w)
    plt.plot(data6 - smoothed, color='r', lw=0.5)
    plt.plot(smoothed, color='g', lw=1.5)
    plt.xlabel('Window length ' + str(w), fontsize=25)


Figures


In [27]:
# Mean RMS versus window length for all six spike trains, saved to disk.
# Each curve is labelled with its peak-to-peak spacing: 10 samples from the
# spike template plus 0, 5, 10, 15, 20 or 25 gap zeros. Without labels the
# six curves in the saved PNG cannot be told apart.
all_values = (values1, values2, values3, values4, values5, values6)
peak_spacings = (10, 15, 20, 25, 30, 35)
plt.figure()
plt.axis([0.0, 40, 0, .5])
for curve, row, spacing in zip(all_values, range(40, 241, 40), peak_spacings):
    plt.plot(wl, curve, color=viridis_data[row, :],
             label="Peak-Peak = %d samples" % spacing)
plt.xlabel('Window length', fontsize=15)
plt.ylabel('RMS', fontsize=15)
plt.legend(loc='lower right', prop={'size': 8})
plt.savefig('RMS spiketrains.png', dpi=300)



In [44]:
# Mean RMS versus window length for spike trains 1, 3, 5 and 6 only, with a
# legend of their peak-to-peak spacing. Saved as 'RMS spiketrains A.png'.
curves = [
    (values1, 40, "Peak-Peak = 10 samples"),
    (values3, 100, "Peak-Peak = 20 samples"),
    (values5, 160, "Peak-Peak = 30 samples"),
    (values6, 220, "Peak-Peak = 35 samples"),
]
plt.figure()
plt.axis([0.0, 40, 0, .5])
for curve, row, text in curves:
    plt.plot(wl, curve, color=viridis_data[row, :], label=text)
plt.xlabel('Window length', fontsize=15)
plt.ylabel('RMS', fontsize=15)
plt.legend(loc='lower right', prop={'size': 8})
plt.savefig('RMS spiketrains A.png', dpi=300)



In [29]:
# Spike train 1 (peak spacing 10) filtered with window length 5. Curves: noisy
# input, filter output, clean train minus filter output.
smoothed = medfilt(data1_noised, 5)  # computed once, used by both curves
plt.figure()
plt.axis([0, 100, -0.7, 1.4])
plt.plot(data1_noised, color='cornflowerblue')
plt.plot(smoothed, color='green', linewidth=1.5)
plt.plot(data1 - smoothed, color='red', linewidth=.5)
plt.savefig("Spiketrain dist 10, wl = 5.png", dpi=300)



In [71]:
# Spike train 1 (peak spacing 10) filtered with window length 15. Curves: noisy
# input, filter output, clean train minus filter output.
smoothed = medfilt(data1_noised, 15)  # computed once, used by both curves
plt.figure()
plt.axis([0, 100, -0.7, 1.4])
plt.plot(data1_noised, color='cornflowerblue')
plt.plot(smoothed, color='green', linewidth=1.5)
plt.plot(data1 - smoothed, color='red', linewidth=.5)
plt.savefig("Spiketrain dist 10, wl 15.png", dpi=300)



In [72]:
# Spike train 1 (peak spacing 10) filtered with window length 19. Curves: noisy
# input, filter output, clean train minus filter output.
smoothed = medfilt(data1_noised, 19)  # computed once, used by both curves
plt.figure()
plt.axis([0, 100, -0.7, 1.4])
plt.plot(data1_noised, color='cornflowerblue')
plt.plot(smoothed, color='green', linewidth=1.5)
plt.plot(data1 - smoothed, color='red', linewidth=.5)
plt.savefig("Spiketrain dist 10, wl 19.png", dpi=300)



In [73]:
# Spike train 3 (peak spacing 20) filtered with window length 19. Curves: noisy
# input, filter output, clean train minus filter output.
smoothed = medfilt(data3_noised, 19)  # computed once, used by both curves
plt.figure()
plt.axis([0, 100, -0.7, 1.4])
plt.plot(data3_noised, color='cornflowerblue')
plt.plot(smoothed, color='green', linewidth=1.5)
plt.plot(data3 - smoothed, color='red', linewidth=.5)
plt.savefig("Spiketrain dist 20, wl 19.png", dpi=300)



In [78]:
# Spike train 5 (peak spacing 30) filtered with window length 11. Curves: noisy
# input, filter output, clean train minus filter output.
smoothed = medfilt(data5_noised, 11)  # computed once, used by both curves
plt.figure()
plt.axis([0, 100, -0.7, 1.4])
plt.plot(data5_noised, color='cornflowerblue')
plt.plot(smoothed, color='green', linewidth=1.5)
plt.plot(data5 - smoothed, color='red', linewidth=.5)
plt.savefig("Spiketrain dist 30, wl 11.png", dpi=300)



In [75]:
# Spike train 6 (peak spacing 35) filtered with window length 15. Curves: noisy
# input, filter output, clean train minus filter output.
smoothed = medfilt(data6_noised, 15)  # computed once, used by both curves
plt.figure()
plt.axis([0, 100, -0.7, 1.4])
plt.plot(data6_noised, color='cornflowerblue')
plt.plot(smoothed, color='green', linewidth=1.5)
plt.plot(data6 - smoothed, color='red', linewidth=.5)
plt.savefig("Spiketrain dist 35, wl 15.png", dpi=300)